package awesomeTask

import (
	"database/sql"
	"fmt"
	"strconv"
	"strings"
	"sync"

	awesomeTask "awesomeTask/config"
	"awesomeTask/system/helper"
	"awesomeTask/system/log"

	_ "github.com/go-sql-driver/mysql"
)

// Execution modes understood by buildQuery.
var (
	query_all = "QUERY_ALL"
	query_one = "QUERY_ONE"
)

// ConnectionInstance caches opened *sql.DB pools, keyed by driver name plus DSN
// (see GetConnection).
var ConnectionInstance = make(map[string]*sql.DB)

// Database describes an explicit connection target. It can be passed to
// BeginTransaction (resolved through getConnection) instead of relying on the
// "database.master" configuration entry. Every field is concatenated verbatim
// into the DSN "User:Password@tcp(Host:Port)/Database".
type Database struct {
	Host, Port     string
	Driver         string
	User, Password string
	Database       string
}

// BeginTransaction starts a transaction and runs fun.
//
// configOrDefault selects the connection the same way getConnection does: a
// *Database is used directly, and any other value selects the
// "database.master" configuration entry.
//
// The transaction is committed when fun returns true and rolled back
// otherwise. BeginTransaction returns true only if the commit succeeded. It
// returns false without calling fun if the transaction could not be started.
//
// NOTE(review): fun receives no *sql.Tx. Statements issued inside fun through
// this file's helpers (Query, Update, Insert, ...) go through
// getConnection(false), i.e. the pool, not through tx. Confirm this is the
// intended contract.
func BeginTransaction(configOrDefault interface{}, fun func() bool) bool {
	db := getConnection(configOrDefault)
	tx, err := db.Begin()
	if err != nil {
		// tx is nil here; calling Rollback on it would panic.
		return false
	}
	// Rollback after a successful Commit is a no-op (it returns sql.ErrTxDone).
	// Deferring it also releases the transaction if fun returns false or panics.
	defer tx.Rollback()

	if !fun() {
		return false
	}
	return tx.Commit() == nil
}

// connectionMu guards ConnectionInstance. Concurrent reads and writes of a Go
// map are a fatal runtime error, so every access in GetConnection holds it.
var connectionMu sync.Mutex

// GetConnection returns a MySQL-style *sql.DB for the given parameters.
//
// Pools are cached in ConnectionInstance under driver+DSN. A cached pool is
// reused if it answers Ping. Otherwise a new pool is opened with sql.Open and
// replaces the cache entry.
//
// It panics if sql.Open fails (for example, an unregistered driver name).
//
// NOTE(review): a pool replaced after a failed Ping is not closed, because
// other goroutines may still hold it. Confirm whether it should be.
func GetConnection(host string, port string, driver string, user string, password string, database string) (db *sql.DB) {
	dsn := user + ":" + password + "@tcp(" + host + ":" + port + ")/" + database
	// The driver name is part of the key so identical DSNs for different
	// drivers do not share a pool.
	cacheKey := driver + dsn

	connectionMu.Lock()
	cached := ConnectionInstance[cacheKey]
	connectionMu.Unlock()

	// Ping outside the lock so a slow or unreachable server does not block
	// lookups for other connections.
	if cached != nil && cached.Ping() == nil {
		return cached
	}

	db, err := sql.Open(driver, dsn)
	if err != nil {
		panic(err.Error())
	}

	connectionMu.Lock()
	ConnectionInstance[cacheKey] = db
	connectionMu.Unlock()
	return db
}
// GetConnectionByDatabaseName connects to the named database. It uses the
// host, port, driver, user and password from the "database.master"
// configuration entry.
//
// It panics if that entry is not a map[string]interface{} or if any of those
// fields is not a string.
func GetConnectionByDatabaseName(database string) (db *sql.DB) {
	master := awesomeTask.GetConfigByKey("database.master").(map[string]interface{})
	host := master["host"].(string)
	port := master["port"].(string)
	driver := master["driver"].(string)
	user := master["user"].(string)
	password := master["password"].(string)
	return GetConnection(host, port, driver, user, password, database)
}
// getConnection resolves configOrDefault to a *sql.DB.
//
// A *Database value supplies all connection parameters itself. Any other
// value (the helpers in this file pass false) selects the "database.master"
// configuration entry, including its "database" field.
func getConnection(configOrDefault interface{}) (db *sql.DB) {
	if cfg, ok := configOrDefault.(*Database); ok {
		return GetConnection(cfg.Host, cfg.Port, cfg.Driver, cfg.User, cfg.Password, cfg.Database)
	}

	master := awesomeTask.GetConfigByKey("database.master").(map[string]interface{})
	return GetConnection(
		master["host"].(string),
		master["port"].(string),
		master["driver"].(string),
		master["user"].(string),
		master["password"].(string),
		master["database"].(string),
	)
}
// Execute runs a statement that returns no rows on the default
// ("database.master") connection. args are bound to its "?" placeholders.
// It returns the driver's sql.Result and any execution error.
//
// The parameter is named query rather than sql so it no longer shadows the
// database/sql package inside the function body.
func Execute(query string, args ...interface{}) (sql.Result, error) {
	return getConnection(false).Exec(query, args...)
}
// buildQuery prepares queryString on the default connection and executes it.
//
// queryModel selects the mode:
//   - query_all returns (*sql.Rows, error) from Stmt.Query. On error the
//     *sql.Rows is a typed nil.
//   - any other value returns (*sql.Row, nil) from Stmt.QueryRow. Its error
//     surfaces later from Row.Scan.
//
// Query and QueryOne forward their variadic params as one []interface{}
// argument. A single []interface{} argument is therefore unwrapped here and
// its elements are bound to the "?" placeholders.
//
// It panics if the statement cannot be prepared. The panic message includes
// the SQL with its arguments substituted for readability.
func buildQuery(queryString string, queryModel string, params ...interface{}) (interface{}, error) {
	args := params
	if len(params) == 1 {
		if list, ok := params[0].([]interface{}); ok {
			args = list
		}
	}

	stmt, err := getConnection(false).Prepare(queryString)
	if err != nil {
		// Stringify with fmt.Sprint: arguments may be of any type. The old
		// item.(string) assertion panicked here and hid the real SQL error.
		shown := make([]string, len(args))
		for i, a := range args {
			shown[i] = fmt.Sprint(a)
		}
		panic(err.Error() + "==>sql:" + helper.Format(strings.ReplaceAll(queryString, "?", "{}"), shown...))
	}
	// Closing the statement before the caller consumes the result is safe:
	// database/sql defers the driver-level close until the Rows/Row obtained
	// from it are closed.
	defer stmt.Close()

	if queryModel == query_all {
		rows, queryErr := stmt.Query(args...)
		return rows, queryErr
	}
	return stmt.QueryRow(args...), nil
}
// Query runs queryString with params bound to its "?" placeholders. It returns
// every result row as a map from column name to value.
//
// Value conversion per column:
//   - []byte values are converted to string;
//   - NULL values are omitted from the row's map;
//   - any other driver value (e.g. int64) is stored unchanged.
//
// The returned slice is never nil, even when there are no rows.
//
// It panics if the query cannot be prepared or executed, or if reading the
// result set fails.
func Query(queryString string, params ...interface{}) []map[string]interface{} {
	rowData, err := buildQuery(queryString, query_all, params)
	if err != nil {
		shown := make([]string, len(params))
		for i, p := range params {
			shown[i] = fmt.Sprint(p)
		}
		queryString1 := strings.ReplaceAll(queryString, "?", "{}")
		panic(err.Error() + "queryString==>" + queryString + "==>sql:" + helper.Format(queryString1, shown...))
	}
	rows := rowData.(*sql.Rows)
	// Release the pooled connection even if a panic below interrupts iteration.
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		panic(err.Error() + "queryString==>" + queryString)
	}
	values := make([]interface{}, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}

	result := []map[string]interface{}{}
	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			panic(err.Error() + "queryString==>" + queryString)
		}
		record := make(map[string]interface{}, len(columns))
		for i, col := range values {
			switch v := col.(type) {
			case nil:
				// NULL: leave the column out of the record.
			case []byte:
				record[columns[i]] = string(v)
			default:
				record[columns[i]] = v
			}
		}
		result = append(result, record)
	}
	if err := rows.Err(); err != nil {
		panic(err.Error() + "queryString==>" + queryString)
	}
	return result
}

// QueryOne runs queryString with params bound to its "?" placeholders. It
// returns the first result row as a map from column name to value, or nil if
// the query produced no rows.
//
// Value conversion is the same as in Query:
//   - []byte values become string;
//   - NULL values are omitted;
//   - other driver values are stored unchanged.
//
// It panics if the query cannot be prepared or executed, or if reading the
// result set fails.
func QueryOne(queryString string, params ...interface{}) map[string]interface{} {
	rowData, err := buildQuery(queryString, query_all, params)
	if err != nil {
		shown := make([]string, len(params))
		for i, p := range params {
			shown[i] = fmt.Sprint(p)
		}
		queryString1 := strings.ReplaceAll(queryString, "?", "{}")
		panic(err.Error() + "queryString==>" + queryString + "==>sql:" + helper.Format(queryString1, shown...))
	}
	rows := rowData.(*sql.Rows)
	// This function returns after the first row, before the result set is
	// exhausted. Rows close automatically only once Next returns false, so
	// without this Close the pooled connection would stay checked out.
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		panic(err.Error() + "queryString==>" + queryString)
	}
	if !rows.Next() {
		if err := rows.Err(); err != nil {
			panic(err.Error() + "queryString==>" + queryString)
		}
		return nil
	}

	values := make([]interface{}, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	if err := rows.Scan(scanArgs...); err != nil {
		panic(err.Error() + "queryString==>" + queryString)
	}

	record := make(map[string]interface{}, len(columns))
	for i, col := range values {
		switch v := col.(type) {
		case nil:
			// NULL: leave the column out of the record.
		case []byte:
			record[columns[i]] = string(v)
		default:
			record[columns[i]] = v
		}
	}
	return record
}

//func QueryOne(queryString string, params ...interface{}) map[string]interface{} {
//	db := getConnection(false)
//	stmt, err := db.Prepare(queryString)
//	if err != nil {
//		list := make([]string, 0)
//		for _, item := range params {
//			list = append(list, item.(string))
//		}
//		queryString = strings.ReplaceAll(queryString, "?", "{}")
//		panic(err.Error() + "==>sql:" + helper.Format(queryString, list...))
//	}
//	var rows *sql.Row
//	if len(params) == 1 {
//		if Type(params[0]) == "list" {
//			rows= stmt.QueryRow(params[0].([]interface{})...)
//		} else {
//			rows = stmt.QueryRow(params...)
//		}
//	} else {
//		rows= stmt.QueryRow(params...)
//	}
//	if err != nil {
//		list := make([]string, 0)
//		for _, item := range params {
//			list = append(list, item.(string))
//		}
//		queryString1 := strings.ReplaceAll(queryString, "?", "{}")
//		panic(err.Error() + "queryString==>" + queryString + "==>sql:" + helper.Format(queryString1, list...))
//	}
//	return rows.Scan()
//}
// Update executes "UPDATE <table> set k1=?,k2=?,... <whereString>" on the
// default connection. It binds the values of updateMap first and then
// whereParams, and returns the number of affected rows.
//
// table, the keys of updateMap and whereString (e.g. "where id=?") are
// interpolated into the SQL text, so they must not come from untrusted input.
// Only the values are bound as parameters.
//
// It panics if the statement cannot be prepared or executed, or if the driver
// cannot report the affected-row count.
func Update(table string, updateMap map[string]interface{}, whereString string, whereParams ...interface{}) int64 {
	var sb strings.Builder
	sb.WriteString("UPDATE " + table + " set ")
	args := make([]interface{}, 0, len(updateMap)+len(whereParams))
	for key, value := range updateMap {
		// At this point args holds only SET values, so a non-empty args means
		// a previous assignment needs a separating comma.
		if len(args) > 0 {
			sb.WriteString(",")
		}
		sb.WriteString(key + "=?")
		args = append(args, value)
	}
	args = append(args, whereParams...)
	sb.WriteString(" " + whereString)
	updateString := sb.String()

	stmt, err := getConnection(false).Prepare(updateString)
	if err != nil {
		panic(err.Error() + "sql:" + updateString)
	}
	// Prepared statements hold server-side resources until they are closed.
	defer stmt.Close()

	res, err := stmt.Exec(args...)
	if err != nil {
		panic(err.Error() + "sql:" + updateString + " params:" + fmt.Sprint(args))
	}
	num, err := res.RowsAffected()
	if err != nil {
		panic(err.Error() + "sql:" + updateString)
	}
	return num
}
func Insert(tableName string, insertFileds map[string]interface{}) int64 {
	insertString := "INSERT INTO " + tableName
	insertListString := "("
	insertValueString := "("
	insertList := []interface{}{}
	for key, value := range insertFileds {
		insertListString += key + ","
		insertList = append(insertList, value)
		insertValueString += "?,"
	}
	insertListString = insertListString[0:len(insertListString)-1] + ")"
	insertValueString = insertValueString[0:len(insertValueString)-1] + ")"
	insertStringRes := insertString + insertListString + " VALUES " + insertValueString
	stmt, err := getConnection(false).Prepare(insertStringRes)
	if err != nil {
		log.GetLogger().Error("插入语句执行错误:" + " error:" + err.Error() + " sql:" + insertStringRes)
		return 0
	}
	res, err := stmt.Exec(insertList...)
	if err != nil {
		log.GetLogger().Error("插入语句执行错误:" + " error:" + err.Error() + " sql:" + insertStringRes)
		return 0
	}
	id, err := res.LastInsertId()
	if err != nil {
		log.GetLogger().Error("插入语句执行错误:" + " error:" + err.Error() + " sql:" + insertStringRes)
		return 0
	}

	return id
}
func InsertOrUpdate(tableName string, insertFileds map[string]interface{}) bool {
	insertString := "INSERT INTO " + tableName
	insertListString := "("
	insertValueString := "("
	insertList := []interface{}{}
	updateString := ""
	for key, value := range insertFileds {
		insertListString += key + ","
		insertList = append(insertList, value)
		insertValueString += "?,"
		updateString += "`" + key + "`=values(" + "`" + key + "`),"
	}
	insertListString = insertListString[0:len(insertListString)-1] + ")"
	insertValueString = insertValueString[0:len(insertValueString)-1] + ")"
	insertStringRes := insertString + insertListString + " VALUES " + insertValueString
	insertStringRes += " ON DUPLICATE KEY UPDATE "
	insertStringRes += updateString
	insertStringRes = insertStringRes[0 : len(insertStringRes)-1]
	stmt, err := getConnection(false).Prepare(insertStringRes)
	if err != nil {
		log.GetLogger().Error("插入语句执行错误:" + " error:" + err.Error() + " sql:" + insertStringRes)
		return false
	}
	res, err := stmt.Exec(insertList...)
	if err != nil {
		log.GetLogger().Error("插入语句执行错误:" + " error:" + err.Error() + " sql:" + insertStringRes)
		return false
	}
	_, err = res.RowsAffected()
	res.LastInsertId()
	if err != nil {
		log.GetLogger().Error("插入语句执行错误:" + " error:" + err.Error() + " sql:" + insertStringRes)
		return false
	}
	return true
}
// WhereIn builds the placeholder list and bind arguments for an SQL "IN"
// clause.
//
// It returns "(?,?,...)" with one placeholder per element of list, plus the
// arguments to bind:
//   - int64 and int elements are converted to their decimal string form;
//   - string elements are passed through;
//   - any other element type panics.
//
// For an empty list it binds a single sentinel string ("empty set" followed
// by NowInt()). This keeps the clause syntactically valid and presumably
// matches no rows.
//
// The caller's slice is never modified.
func WhereIn(list []interface{}) (string, []interface{}) {
	if len(list) == 0 {
		// Use a fresh slice: appending to an empty caller slice with spare
		// capacity would write into the caller's backing array.
		list = []interface{}{"empty set" + strconv.FormatInt(NowInt(), 10)}
	}
	args := make([]interface{}, len(list))
	for i, item := range list {
		switch v := item.(type) {
		case int64:
			args[i] = strconv.FormatInt(v, 10)
		case int:
			args[i] = strconv.Itoa(v)
		default:
			args[i] = item.(string)
		}
	}
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(list)), ",")
	return "(" + placeholders + ")", args
}
// Type returns a short name for the dynamic type of i:
//   - "string" for string;
//   - "int64" for int64;
//   - "bool" for bool;
//   - "uit8" for uint8 (byte);
//   - "list" for []interface{};
//   - "" for anything else, including nil.
//
// NOTE(review): "uit8" looks like a typo for "uint8". It is kept unchanged
// because other code may compare against the existing value.
func Type(i interface{}) string {
	var name string
	switch i.(type) {
	case []interface{}:
		name = "list"
	case uint8:
		name = "uit8"
	case bool:
		name = "bool"
	case int64:
		name = "int64"
	case string:
		name = "string"
	}
	return name
}
// Pager runs queryString as a paged query.
//
// It first counts the rows of queryString (wrapped as a subquery). It then
// fetches one page with LIMIT pageSize OFFSET pageSize*(page-1). page is
// 1-based; values below 1 are treated as page 1.
//
// The returned map contains:
//   - "list": the page rows, as returned by Query;
//   - "total": the "num" column of the count query, or nil if it returned
//     no row;
//   - "page": the page actually used;
//   - "pageSize".
func Pager(queryString string, page int, pageSize int, params ...interface{}) map[string]interface{} {
	if page < 1 {
		// page < 1 would produce a negative OFFSET, which MySQL rejects.
		page = 1
	}
	totalSql := "select count(1) as num from (" + queryString + ") as tmp"
	// Indexing a nil map yields nil, so a missing count row gives total == nil.
	total := QueryOne(totalSql, params...)["num"]

	offset := pageSize * (page - 1)
	listSql := "select tmp.* from (" + queryString + ") as tmp limit " + strconv.Itoa(pageSize) + " offset " + strconv.Itoa(offset)
	list := Query(listSql, params...)
	return map[string]interface{}{
		"list":     list,
		"total":    total,
		"page":     page,
		"pageSize": pageSize,
	}
}
